import copy
import re

import numpy as np
from plotly import colors

from ..util import color_intervals, process_cmap

# Constants
# ---------

# Trace types that are individually positioned with their own domain.
# These are traces that don't overlay on top of each other in a shared
# subplot, so they are positioned individually.  All other trace types are
# associated with a layout subplot type (xaxis/yaxis, polar, scene, etc.).
#
# Each of these trace types has a `domain` property with `x`/`y` properties.
_domain_trace_types = {'parcoords', 'pie', 'table', 'sankey', 'parcats'}

# Subplot types that are each individually positioned with a domain.
#
# Each of these subplot types has a `domain` property with `x`/`y` properties.
# Note that this set does not contain `xaxis`/`yaxis` because these behave a
# little differently: their `domain` property is itself the two-element
# range rather than a container with `x`/`y` properties.
_subplot_types = {'scene', 'geo', 'polar', 'ternary', 'mapbox'}

# For most subplot types, a trace is associated with a particular subplot
# using a trace property with a name that matches the subplot type. For
# example, a `scatter3d.scene` property set to `'scene2'` associates a
# scatter3d trace with the second `scene` subplot in the figure.
#
# There are a few subplot types that don't follow this pattern, and instead
# the trace property is just named `subplot`.  For example, setting
# the `scatterpolar.subplot` property to `polar3` associates the scatterpolar
# trace with the third polar subplot in the figure.
_subplot_prop_named_subplot = {'polar', 'ternary', 'mapbox'}

# Mapping from trace type to the list of layout subplot type(s) that a trace
# of that type is attached to.  Trace types missing from this mapping are
# treated by the functions below as not belonging to any layout subplot.
_trace_to_subplot = {
    # xaxis/yaxis
    'bar':                  ['xaxis', 'yaxis'],
    'box':                  ['xaxis', 'yaxis'],
    'candlestick':          ['xaxis', 'yaxis'],
    'carpet':               ['xaxis', 'yaxis'],
    'contour':              ['xaxis', 'yaxis'],
    'contourcarpet':        ['xaxis', 'yaxis'],
    'heatmap':              ['xaxis', 'yaxis'],
    'heatmapgl':            ['xaxis', 'yaxis'],
    'histogram':            ['xaxis', 'yaxis'],
    'histogram2d':          ['xaxis', 'yaxis'],
    'histogram2dcontour':   ['xaxis', 'yaxis'],
    'ohlc':                 ['xaxis', 'yaxis'],
    'pointcloud':           ['xaxis', 'yaxis'],
    'scatter':              ['xaxis', 'yaxis'],
    'scattercarpet':        ['xaxis', 'yaxis'],
    'scattergl':            ['xaxis', 'yaxis'],
    'violin':               ['xaxis', 'yaxis'],

    # scene
    'cone':         ['scene'],
    'mesh3d':       ['scene'],
    'scatter3d':    ['scene'],
    'streamtube':   ['scene'],
    'surface':      ['scene'],

    # geo
    'choropleth': ['geo'],
    'scattergeo': ['geo'],

    # polar
    'barpolar':         ['polar'],
    'scatterpolar':     ['polar'],
    'scatterpolargl':   ['polar'],

    # ternary
    'scatterternary': ['ternary'],

    # mapbox
    'scattermapbox': ['mapbox']
}

# Aliases - map style option names (keys) to the names used in their place
# (values).
# NOTE(review): presumably the values are the plotly property names that the
# aliased options are translated to - confirm against the callers.

STYLE_ALIASES = {'line_width': 'width', 'alpha': 'opacity',
                 'cell_height': 'height', 'marker': 'symbol'}

# Regular expression to extract any trailing digits from a subplot-style
# string.
_subplot_re = re.compile('\D*(\d+)')


def _get_subplot_number(subplot_val):
    """
    Extract the subplot number from a subplot value string.

    'x3' -> 3
    'polar2' -> 2
    'scene' -> 1
    'y' -> 1

    Note: the absence of a subplot number (e.g. 'y') is treated by plotly as
    a subplot number of 1

    Parameters
    ----------
    subplot_val: str
        Subplot string value (e.g. 'scene4')

    Returns
    -------
    int
    """
    match = _subplot_re.match(subplot_val)
    if match:
        subplot_number = int(match.group(1))
    else:
        subplot_number = 1
    return subplot_number


def _get_subplot_val_prefix(subplot_type):
    """
    Get the subplot value prefix for a subplot type. For most subplot types
    this is equal to the subplot type string itself. For example, a
    `scatter3d.scene` value of `scene2` is used to associate the scatter3d
    trace with the `layout.scene2` subplot.

    However, the `xaxis`/`yaxis` subplot types are exceptions to this pattern.
    For example, a `scatter.xaxis` value of `x2` is used to associate the
    scatter trace with the `layout.xaxis2` subplot.

    Parameters
    ----------
    subplot_type: str
        Subplot string value (e.g. 'scene4')

    Returns
    -------
    str
    """
    if subplot_type == 'xaxis':
        subplot_val_prefix = 'x'
    elif subplot_type == 'yaxis':
        subplot_val_prefix = 'y'
    else:
        subplot_val_prefix = subplot_type
    return subplot_val_prefix


def _get_subplot_prop_name(subplot_type):
    """
    Return the name of the trace property that attaches a trace to a subplot
    of the given type.

    Usually the property is named after the subplot type, e.g. the
    `scatter3d.scene` property selects the `scene` subplot of a `scatter3d`
    trace.

    The subplot types listed in `_subplot_prop_named_subplot` are the
    exceptions: for those the trace property is called `subplot`, e.g. the
    `scatterpolar.subplot` property selects the `polar` subplot of a
    `scatterpolar` trace.

    Parameters
    ----------
    subplot_type: str
        Subplot type string (e.g. 'scene' or 'polar')

    Returns
    -------
    str
        'subplot' for the exceptional subplot types, otherwise `subplot_type`
    """
    uses_generic_name = subplot_type in _subplot_prop_named_subplot
    return 'subplot' if uses_generic_name else subplot_type


def _normalize_subplot_ids(fig):
    """
    Make sure a layout subplot property is initialized for every subplot that
    is referenced by a trace in the figure.

    For example, if a figure contains a `scatterpolar` trace with the `subplot`
    property set to `polar3`, this function will make sure the figure's layout
    has a `polar3` property, and will initialize it to an empty dict if it
    does not

    Note: This function mutates the input figure dict

    Parameters
    ----------
    fig: dict
        A plotly figure dict
    """

    layout = fig.setdefault('layout', {})
    for trace in fig.get('data', None):
        trace_type = trace.get('type', 'scatter')
        subplot_types = _trace_to_subplot.get(trace_type, [])
        for subplot_type in subplot_types:

            subplot_prop_name = _get_subplot_prop_name(subplot_type)
            subplot_val_prefix = _get_subplot_val_prefix(subplot_type)
            subplot_val = trace.get(subplot_prop_name, subplot_val_prefix)

            # extract trailing number (if any)
            subplot_number = _get_subplot_number(subplot_val)

            if subplot_number > 1:
                layout_prop_name = subplot_type + str(subplot_number)
            else:
                layout_prop_name = subplot_type

            if layout_prop_name not in layout:
                layout[layout_prop_name] = {}


def _get_max_subplot_ids(fig):
    """
    Find the largest subplot number used by the traces of a figure, for each
    subplot type.

    Parameters
    ----------
    fig: dict
        A plotly figure dict

    Returns
    -------
    dict
        A dict from subplot type strings to integers indicating the largest
        subplot number referenced by a trace of the figure for that subplot
        type. Every type in `_subplot_types`, plus 'xaxis' and 'yaxis', is
        present; types that no trace references map to 0
    """
    # Start every known subplot type at zero ("not used")
    largest = dict.fromkeys(_subplot_types, 0)
    largest['xaxis'] = 0
    largest['yaxis'] = 0

    for trace in fig.get('data', []):
        # plotly treats a trace without an explicit type as a scatter trace
        kinds = _trace_to_subplot.get(trace.get('type', 'scatter'), [])
        for kind in kinds:
            prop_name = _get_subplot_prop_name(kind)
            default_val = _get_subplot_val_prefix(kind)

            # A trace that does not name its subplot uses the first one
            number = _get_subplot_number(trace.get(prop_name, default_val))

            if number > largest[kind]:
                largest[kind] = number

    return largest


def _offset_subplot_ids(fig, offsets):
    """
    Apply offsets to the subplot id numbers in a figure.

    Note: This function mutates the input figure dict

    Note: This function assumes that the normalize_subplot_ids function has
    already been run on the figure, so that all layout subplot properties in
    use are explicitly present in the figure's layout.

    Parameters
    ----------
    fig: dict
        A plotly figure dict
    offsets: dict
        A dict from subplot types to the offset to be applied for each subplot
        type.  This dict matches the form of the dict returned by
        get_max_subplot_ids
    """
    # Offset traces
    for trace in fig.get('data', None):
        trace_type = trace.get('type', 'scatter')
        subplot_types = _trace_to_subplot.get(trace_type, [])

        for subplot_type in subplot_types:
            subplot_prop_name = _get_subplot_prop_name(subplot_type)

            # Compute subplot value prefix
            subplot_val_prefix = _get_subplot_val_prefix(subplot_type)
            subplot_val = trace.get(subplot_prop_name, subplot_val_prefix)
            subplot_number = _get_subplot_number(subplot_val)

            offset_subplot_number = (
                    subplot_number + offsets.get(subplot_type, 0))

            if offset_subplot_number > 1:
                trace[subplot_prop_name] = (
                        subplot_val_prefix + str(offset_subplot_number))
            else:
                trace[subplot_prop_name] = subplot_val_prefix

    # layout subplots
    layout = fig.setdefault('layout', {})
    new_subplots = {}

    for subplot_type in offsets:
        offset = offsets[subplot_type]
        if offset < 1:
            continue

        for layout_prop in list(layout.keys()):
            if layout_prop.startswith(subplot_type):
                subplot_number = _get_subplot_number(layout_prop)
                new_subplot_number = subplot_number + offset
                new_layout_prop = subplot_type + str(new_subplot_number)
                new_subplots[new_layout_prop] = layout.pop(layout_prop)

    layout.update(new_subplots)

    # xaxis/yaxis anchors
    x_offset = offsets.get('xaxis', 0)
    y_offset = offsets.get('yaxis', 0)

    for layout_prop in list(layout.keys()):
        if layout_prop.startswith('xaxis'):
            xaxis = layout[layout_prop]
            anchor = xaxis.get('anchor', 'y')
            anchor_number = _get_subplot_number(anchor) + y_offset
            if anchor_number > 1:
                xaxis['anchor'] = 'y' + str(anchor_number)
            else:
                xaxis['anchor'] = 'y'
        elif layout_prop.startswith('yaxis'):
            yaxis = layout[layout_prop]
            anchor = yaxis.get('anchor', 'x')
            anchor_number = _get_subplot_number(anchor) + x_offset
            if anchor_number > 1:
                yaxis['anchor'] = 'x' + str(anchor_number)
            else:
                yaxis['anchor'] = 'x'

    # annotations/shapes/images
    for layout_prop in ['annotations', 'shapes', 'images']:
        for obj in layout.get(layout_prop, []):
            if x_offset:
                xref = obj.get('xref', 'x')
                if xref != 'paper':
                    xref_number = _get_subplot_number(xref)
                    obj['xref'] = 'x' + str(xref_number + x_offset)

            if y_offset:
                yref = obj.get('yref', 'y')
                if yref != 'paper':
                    yref_number = _get_subplot_number(yref)
                    obj['yref'] = 'y' + str(yref_number + y_offset)


def _scale_translate(fig, scale_x, scale_y, translate_x, translate_y):
    """
    Scale a figure and translate it to sub-region of the original
    figure canvas.

    The domains of domain-positioned traces, of layout subplots and of
    xaxis/yaxis properties are rescaled, as are the paper-referenced
    coordinates of annotations.

    Note: If the input figure has a title, this title is converted into an
    annotation and scaled along with the rest of the figure.

    Note: This function mutates the input fig dict

    Note: This function assumes that the normalize_subplot_ids function has
    already been run on the figure, so that all layout subplot properties in
    use are explicitly present in the figure's layout.

    Parameters
    ----------
    fig: dict
        A plotly figure dict
    scale_x: float
        Factor by which to scale the figure in the x-direction. This will
        typically be a value < 1.  E.g. a value of 0.5 will cause the
        resulting figure to be half as wide as the original.
    scale_y: float
        Factor by which to scale the figure in the y-direction. This will
        typically be a value < 1
    translate_x: float
        Factor by which to translate the scaled figure in the x-direction in
        normalized coordinates.
    translate_y: float
        Factor by which to translate the scaled figure in the y-direction in
        normalized coordinates.
    """
    data = fig.setdefault('data', [])
    layout = fig.setdefault('layout', {})

    def scale_translate_x(x):
        # Map a [start, stop] x-range into the target sub-region
        return [x[0] * scale_x + translate_x,
                x[1] * scale_x + translate_x]

    def scale_translate_y(y):
        # Map a [start, stop] y-range into the target sub-region
        return [y[0] * scale_y + translate_y,
                y[1] * scale_y + translate_y]

    def perform_scale_translate(obj):
        # Rescale the `domain.x`/`domain.y` ranges of a trace or subplot,
        # treating a missing range as the full [0, 1] extent
        domain = obj.setdefault('domain', {})
        x = domain.get('x', [0, 1])
        y = domain.get('y', [0, 1])

        domain['x'] = scale_translate_x(x)
        domain['y'] = scale_translate_y(y)

    # Scale/translate traces
    for trace in data:
        trace_type = trace.get('type', 'scatter')
        if trace_type in _domain_trace_types:
            perform_scale_translate(trace)

    # Scale/translate subplot containers
    for prop in layout:
        for subplot_type in _subplot_types:
            if prop.startswith(subplot_type):
                perform_scale_translate(layout[prop])

    # xaxis/yaxis store their range directly in `domain` (no x/y sub-keys)
    for prop in layout:
        if prop.startswith('xaxis'):
            xaxis = layout[prop]
            x_domain = xaxis.get('domain', [0, 1])
            xaxis['domain'] = scale_translate_x(x_domain)
        elif prop.startswith('yaxis'):
            yaxis = layout[prop]
            y_domain = yaxis.get('domain', [0, 1])
            yaxis['domain'] = scale_translate_y(y_domain)

    # convert title to annotation
    # This way the annotation will be scaled with the rest of the figure
    annotations = layout.get('annotations', [])

    title = layout.pop('title', None)
    if title:
        titlefont = layout.pop('titlefont', {})
        title_fontsize = titlefont.get('size', 17)
        min_fontsize = 12
        # Shrink the font with the figure width, but never below the minimum
        titlefont['size'] = round(min_fontsize +
                                  (title_fontsize - min_fontsize) * scale_x)

        annotations.append({
            'text': title,
            'showarrow': False,
            'xref': 'paper',
            'yref': 'paper',
            'x': 0.5,
            'y': 1.01,
            'xanchor': 'center',
            'yanchor': 'bottom',
            'font': titlefont
        })
        layout['annotations'] = annotations

    # annotations
    # Each coordinate is rescaled only when that same coordinate is expressed
    # in paper units; a coordinate that references an axis is in data units
    # and must be left alone
    for obj in layout.get('annotations', []):
        if obj.get('xref', None) == 'paper':
            obj['x'] = obj.get('x', 0.5) * scale_x + translate_x
        if obj.get('yref', None) == 'paper':
            obj['y'] = obj.get('y', 0.5) * scale_y + translate_y


def merge_figure(fig, subfig):
    """
    Merge the traces and layout of a sub-figure into a parent figure

    The traces of `subfig` are deep-copied and appended to the traces of
    `fig`; the layout of `subfig` is merged into the layout of `fig` with
    `_merge_layout_objs`. Missing `data`/`layout` properties of `fig` are
    created.

    Note: This function mutates the input fig dict, but it does not mutate
    the subfig dict

    Parameters
    ----------
    fig: dict
        The plotly figure dict into which the sub figure will be merged
    subfig: dict
        The plotly figure dict that will be copied and then merged into `fig`
    """

    # traces: append independent copies of the sub-figure's traces
    parent_traces = fig.setdefault('data', [])
    sub_traces = subfig.get('data', [])
    parent_traces.extend(copy.deepcopy(sub_traces))

    # layout: merged property by property
    parent_layout = fig.setdefault('layout', {})
    sub_layout = subfig.get('layout', {})
    _merge_layout_objs(parent_layout, sub_layout)


def _merge_layout_objs(obj, subobj):
    """
    Merge layout objects recursively

    Note: This function mutates the input obj dict, but it does not mutate
    the subobj dict

    Parameters
    ----------
    obj: dict
        dict into which the sub-figure dict will be merged
    subobj: dict
        dict that sill be copied and merged into `obj`
    """
    for prop, val in subobj.items():
        if isinstance(val, dict) and prop in obj:
            # recursion
            _merge_layout_objs(obj[prop], val)
        elif (isinstance(val, list) and
              obj.get(prop, None) and
              isinstance(obj[prop][0], dict)):

            # append
            obj[prop].extend(val)
        else:
            # init/overwrite
            obj[prop] = copy.deepcopy(val)


def _compute_subplot_domains(widths, spacing):
    """
    Compute normalized domain tuples for a list of widths and a subplot
    spacing value

    Parameters
    ----------
    widths: list of float
        List of the desired withs of each subplot. The length of this list
        is also the specification of the number of desired subplots
    spacing: float
        Spacing between subplots in normalized coordinates

    Returns
    -------
    list of tuple of float
    """
    # normalize widths
    widths_sum = float(sum(widths))
    total_spacing = (len(widths) - 1) * spacing
    widths = [(w / widths_sum)*(1-total_spacing) for w in widths]
    domains = []

    for c in range(len(widths)):
        domain_start = c * spacing + sum(widths[:c])
        domain_stop = min(1, domain_start + widths[c])
        domains.append((domain_start, domain_stop))

    return domains


def figure_grid(figures_grid,
                row_heights=None,
                column_widths=None,
                row_spacing=0.15,
                column_spacing=0.15,
                share_xaxis=False,
                share_yaxis=False):
    """
    Construct a figure from a 2D grid of sub-figures

    The input figures are not modified: each one is deep-copied before its
    subplot ids are offset and its contents are scaled into its grid cell.

    Parameters
    ----------
    figures_grid: list of list of (dict or None)
        2D list of plotly figure dicts that will be combined in a grid to
        produce the resulting figure.  None values may be used to leave empty
        grid cells.  The number of columns is taken from the first row.  An
        empty list produces an empty figure
    row_heights: list of float (default None)
        List of the relative heights of each row in the grid (these values
        will be normalized by the function)
    column_widths: list of float (default None)
        List of the relative widths of each column in the grid (these values
        will be normalized by the function)
    row_spacing: float (default 0.15)
        Vertical spacing between rows in the grid in normalized coordinates
    column_spacing: float (default 0.15)
        Horizontal spacing between columns in the grid in normalized
        coordinates
    share_xaxis: bool (default False)
        Share x-axis between sub-figures in the same column. This will only
        work if each sub-figure has a single x-axis
    share_yaxis: bool (default False)
        Share y-axis between sub-figures in the same row. This will only work
        if each subfigure has a single y-axis

    Returns
    -------
    dict
        A plotly figure dict
    """

    output_figure = {'data': [], 'layout': {}}

    # compute number of rows/cols
    rows = len(figures_grid)
    if rows == 0:
        # No rows means there is no first row to take the column count from
        return output_figure
    columns = len(figures_grid[0])

    # Initialize row heights / column widths
    if not row_heights:
        row_heights = [1 for _ in range(rows)]

    if not column_widths:
        column_widths = [1 for _ in range(columns)]

    # Compute domain widths/heights for subplots
    column_domains = _compute_subplot_domains(column_widths, column_spacing)
    row_domains = _compute_subplot_domains(row_heights, row_spacing)

    for r, (fig_row, row_domain) in enumerate(zip(figures_grid, row_domains)):
        for c, (fig, column_domain) in enumerate(zip(fig_row, column_domains)):
            if fig:
                # Work on a copy so the caller's figure is left untouched
                fig = copy.deepcopy(fig)

                _normalize_subplot_ids(fig)

                # Number this figure's subplots after those already merged
                subplot_offsets = _get_max_subplot_ids(output_figure)

                if share_xaxis:
                    # All figures in column c use the same x axis number
                    subplot_offsets['xaxis'] = c
                    if r != 0:
                        # Only use xaxes from the first row (r == 0)
                        fig.get('layout', {}).pop('xaxis', None)

                if share_yaxis:
                    # All figures in row r use the same y axis number
                    subplot_offsets['yaxis'] = r
                    if c != 0:
                        # Only use yaxes from first column
                        fig.get('layout', {}).pop('yaxis', None)

                _offset_subplot_ids(fig, subplot_offsets)

                # Shrink the figure to the size of its grid cell and move it
                # to the cell's lower-left corner
                scale_x = column_domain[1] - column_domain[0]
                scale_y = row_domain[1] - row_domain[0]
                _scale_translate(fig,
                                 scale_x, scale_y,
                                 column_domain[0], row_domain[0])

                merge_figure(output_figure, fig)

    return output_figure


def get_colorscale(cmap, levels=None, cmin=None, cmax=None):
    """Converts a cmap spec to a plotly colorscale

    Args:
        cmap: A recognized colormap by name or list of colors
        levels: A list or integer declaring the color-levels
        cmin: The lower bound of the color range
        cmax: The upper bound of the color range

    Returns:
        A valid plotly colorscale

    Raises:
        ValueError: If both cmap and levels are lists and the number of
            colors does not match the number of intervals between levels
    """
    ncolors = levels if isinstance(levels, int) else None
    if isinstance(levels, list):
        # N level boundaries delimit N-1 color intervals
        ncolors = len(levels) - 1
        if isinstance(cmap, list) and len(cmap) != ncolors:
            raise ValueError('The number of colors in the colormap '
                             'must match the intervals defined in the '
                             'color_levels, expected %d colors found %d.'
                             % (ncolors, len(cmap)))
    try:
        palette = process_cmap(cmap, ncolors)
    except Exception:
        # Fall back to plotly's own named colorscales.  Only a string can
        # name one; looking up an unhashable cmap (e.g. a list of colors)
        # would raise a TypeError that hides the original error.
        colorscale = None
        if isinstance(cmap, str):
            colorscale = colors.PLOTLY_SCALES.get(cmap)
        if colorscale is None:
            # Re-raise the process_cmap error with its original traceback
            raise
        return colorscale

    if isinstance(levels, int):
        # Build a stepped colorscale: every interior boundary appears twice,
        # once with the color of the interval ending there and once with the
        # color of the interval starting there
        colorscale = []
        scale = np.linspace(0, 1, levels+1)
        for i in range(levels+1):
            if i == 0:
                colorscale.append((scale[0], palette[i]))
            elif i == levels:
                colorscale.append((scale[-1], palette[-1]))
            else:
                colorscale.append((scale[i], palette[i-1]))
                colorscale.append((scale[i], palette[i]))
        return colorscale
    elif isinstance(levels, list):
        # Only the palette is needed; the range that color_intervals also
        # returns is not used here
        palette, _ = color_intervals(
            palette, levels, clip=(cmin, cmax))
    return colors.make_colorscale(palette)
